import pandas as pd
from sklearn import metrics
import textdistance
import itertools
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import random

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score

def train(X, y):
    """Grid-search a random-forest pipeline and report the winning settings.

    A single-step pipeline (step name ``'forest'``) wrapping a
    ``RandomForestClassifier(random_state=40)`` is tuned with ``GridSearchCV``
    (``cv=5``) over ``max_depth`` in {5, 10, 20} and ``n_estimators`` in
    {10, 50, 100}. The best parameters are printed to stdout.

    Args:
        X: Training features, passed straight to ``GridSearchCV.fit``.
        y: Training labels, passed straight to ``GridSearchCV.fit``.

    Returns:
        A ``(model, feat_importances)`` tuple: the object returned by
        ``GridSearchCV.fit`` and the ``feature_importances_`` of the forest
        inside its best estimator.
    """
    forest_pipeline = Pipeline([('forest', RandomForestClassifier(random_state=40))])

    param_grid = {
        'forest__max_depth': [5, 10, 20],
        'forest__n_estimators': [10, 50, 100],
    }

    search = GridSearchCV(
        estimator=forest_pipeline,
        param_grid=param_grid,
        cv=5,
        verbose=0,
    )
    model = search.fit(X, y)

    print("Best parameters:")
    print(model.best_params_)
    print("done training")

    # Pull the tuned forest back out of the pipeline by its step name.
    best_forest = model.best_estimator_.named_steps["forest"]
    return model, best_forest.feature_importances_



def distance_metric(metric, x):
    """Apply ``metric`` to the first two elements of ``x``.

    Args:
        metric: Callable taking two positional arguments.
        x: A pair-like object. Plain sequences (tuple, list) are indexed
            directly; objects exposing ``.iloc`` (e.g. a pandas row Series)
            are read positionally through it.

    Returns:
        Whatever ``metric`` returns for ``(first, second)``.
    """
    # A DataFrame row is indexed by column label, so ``x[0]`` would rely on
    # pandas' deprecated positional fallback; ``.iloc`` is explicitly
    # positional and keeps the "first two elements" meaning.
    positional = getattr(x, 'iloc', x)
    return metric(positional[0], positional[1])

def precompute_distances(amicus, bonica, dump_duplicates=True, combination_type='dif_lists'):
    """Build a table of name pairs with one column per string metric.

    Args:
        amicus: Iterable of names for the ``'amicus'`` column.
        bonica: Iterable of names for the ``'bonica'`` column (ignored when
            ``combination_type == 'same_list'``).
        dump_duplicates: When true, drop rows whose two names are equal.
        combination_type: How pairs are formed.
            ``'same_list'``  - every 2-combination of ``amicus``;
            ``'dif_lists'``  - the Cartesian product ``amicus x bonica``;
            anything else    - ``amicus`` and ``bonica`` are used as two
            already-paired columns.

    Returns:
        A DataFrame with columns ``amicus``, ``bonica``, ``cosine``,
        ``jaccard``, ``levenshtein``, ``lcsstr`` and ``overlap``. Each metric
        column holds the value returned by the corresponding textdistance
        callable for that row's pair.
    """
    # Named so it does not shadow the module-level ``sklearn.metrics`` import.
    distance_functions = {
        'cosine': textdistance.cosine,
        'jaccard': textdistance.jaccard,
        'levenshtein': textdistance.levenshtein,
        'lcsstr': textdistance.lcsstr.distance,
        'overlap': textdistance.overlap,
    }

    if combination_type == 'same_list':
        combos = list(itertools.combinations(amicus, 2))
    elif combination_type == 'dif_lists':
        combos = list(itertools.product(amicus, bonica))
    else:
        combos = None

    if combos is None:
        base_df = pd.DataFrame({'amicus': amicus, 'bonica': bonica})
    else:
        base_df = pd.DataFrame({
            'amicus': [left for left, _ in combos],
            'bonica': [right for _, right in combos],
        })

    if dump_duplicates:
        # Copy so the column assignments below write to an independent frame
        # rather than to the result of a filter on the unfiltered one.
        base_df = base_df.loc[base_df['amicus'] != base_df['bonica']].copy()

    # Score explicit (amicus, bonica) tuples instead of row-wise ``apply``:
    # rows are label-indexed, so positional ``x[0]``/``x[1]`` on them is
    # deprecated in pandas, and ``apply(axis=1)`` on an empty frame does not
    # produce a column-shaped result.
    name_pairs = list(zip(base_df['amicus'].tolist(), base_df['bonica'].tolist()))
    for name, distance_function in distance_functions.items():
        base_df[name] = [distance_metric(distance_function, pair) for pair in name_pairs]

    return base_df

def shuffle(best_matches):
    """Expand the positive pairs into a full cross product with new labels.

    Every positive row's ``amicus`` is paired with every positive row's
    ``bonica`` via ``precompute_distances``. In the resulting n*n table only
    the rows at positions 0, n+1, 2(n+1), ... (where an ``amicus`` meets the
    ``bonica`` from its own original row) are labelled 1; the rest are
    labelled 0. The original negative rows are appended unchanged, restricted
    to the name, metric and label columns.

    Args:
        best_matches: DataFrame with a ``label`` column and, when positives
            exist, the ``amicus``, ``bonica`` and metric columns.

    Returns:
        ``best_matches`` itself when it has no rows labelled 1, otherwise the
        concatenation of the crossed positives and the original negatives.
    """
    positives = best_matches.loc[best_matches['label'] == 1]
    num_positive = len(positives)
    if not num_positive:
        return best_matches

    kept_columns = ['amicus', 'bonica', 'cosine', 'jaccard', 'lcsstr', 'levenshtein', 'overlap', 'label']
    negatives = best_matches.loc[best_matches['label'] == 0, kept_columns]

    crossed = precompute_distances(
        positives['amicus'],
        positives['bonica'],
        dump_duplicates=False,
        combination_type='dif_lists',
    )

    # In the row-major product, the i-th amicus meets the i-th bonica at
    # position i * (n + 1).
    stride = num_positive + 1
    crossed['label'] = [1 if position % stride == 0 else 0 for position in range(len(crossed))]

    return pd.concat([crossed, negatives], axis=0)

def ask_about_matches(match_pairs):
    """Interactively label each pair as a match (1) or non-match (0).

    For every pair, its 1-based position is printed, then the pair is shown
    and the user is prompted repeatedly until the reply is exactly ``'y'`` or
    ``'n'``.

    Args:
        match_pairs: Iterable of pairs to show to the user.

    Returns:
        A list with one entry per pair: 1 for ``'y'``, 0 for ``'n'``.
    """
    label_for_reply = {'y': 1, 'n': 0}
    labels = []
    for position, pair in enumerate(match_pairs, start=1):
        print(position)
        while True:
            print(pair)
            reply = input("Do these match (y or n): ")
            if reply in label_for_reply:
                break
        labels.append(label_for_reply[reply])
    return labels

def get_predictions(model, test_set, colname):
    """Score every row with ``model`` and return the rows best-first.

    Note that the score column is added to ``test_set`` itself, so the
    caller's frame gains ``colname`` as a side effect.

    Args:
        model: Object with ``predict_proba``; column 1 of its output is used
            as the score.
        test_set: DataFrame containing the five metric columns.
        colname: Name of the column that receives the scores.

    Returns:
        The frame sorted by ``colname`` in descending order.
    """
    feature_columns = ['cosine', 'jaccard', 'levenshtein', 'lcsstr', 'overlap']
    probabilities = model.predict_proba(test_set[feature_columns])
    test_set[colname] = probabilities[:, 1]
    return test_set.sort_values(colname, ascending=False)

def amicus_bonica_iteration(num_bonicas, num_labeled_pairs, model, amicus, bonica):
    """Run one labelling round: sample, score, and hand-label the top pairs.

    Steps:
        1. Draw ``num_bonicas`` entries from ``bonica`` without replacement.
        2. Compute distance metrics for every (amicus, sampled bonica) pair.
        3. Score the pairs with ``model`` into a ``basic_score`` column.
        4. Ask the user to label the ``num_labeled_pairs`` highest-scoring
           pairs.

    Args:
        num_bonicas: Number of ``bonica`` entries to sample. Sampling is
            without replacement, so this presumably must not exceed
            ``len(bonica)`` (numpy rejects larger sizes).
        num_labeled_pairs: Number of top-scoring pairs to label.
        model: Fitted classifier exposing ``predict_proba``.
        amicus: Iterable of amicus names.
        bonica: 1-D array-like of bonica names to sample from.

    Returns:
        DataFrame of the top-scoring pairs with their metrics,
        ``basic_score`` and a user-supplied ``label`` column (1 or 0).
    """
    print('selecting bonicas.')
    bonica_sample = np.random.choice(bonica, size=num_bonicas, replace=False)
    print('computing distance metrics.')
    pairs_df = precompute_distances(amicus, bonica_sample)
    print('calculating scores')
    pairs_df = get_predictions(model, pairs_df, 'basic_score')

    # head() returns a slice of pairs_df; copy it so that adding the 'label'
    # column is a write to an independent frame, not chained assignment.
    best_matches = pairs_df.head(num_labeled_pairs).copy()

    pairs = np.array(best_matches[['amicus', 'bonica']])
    best_matches['label'] = ask_about_matches(pairs)
    return best_matches